Support Whisper training with Google Cloud buckets #70
Conversation
jqug
left a comment
Thanks for this, looks good.
Just one thing, let's take out the gcloud auth for now and maybe mention in a comment in the file that this may be necessary.
We should consider merging this notebook into the dedicated sunbird-speech repo:
…ngual_eval_fn processing; fix label 448 limit; launch full training
… fix preprocess error
jqug
left a comment
Thanks, LGTM
I double checked the language token IDs comparing with the Whisper tokenizer, and they look right. Actually I didn't realise that Whisper supports so many African languages already :)
```python
        "target": row.get("text"),
        "target.language": row.get("language"),
    }
    yield example
```
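The row-to-example mapping in the excerpt above can be sketched as a self-contained generator; the surrounding loop and the sample rows below are assumptions for illustration, only the yielded dict mirrors the diff:

```python
# Sketch of the generator pattern from the excerpt above; the loop and
# sample rows are illustrative, only the yielded dict mirrors the diff.
def example_generator(rows):
    for row in rows:
        example = {
            "target": row.get("text"),
            "target.language": row.get("language"),
        }
        yield example

rows = [{"text": "webale nyo", "language": "lug"}]
for ex in example_generator(rows):
    print(ex)
```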
This looks good. A further improvement for later, in case it's an ASR/audio dataset and the format already matches, is to skip the generator entirely: just load the Hugging Face datasets and concatenate them. That should reduce the CPU bottleneck and could improve GPU utilisation.
This PR made the following changes to the salt library for running the latest Whisper finetuning:
- Added a `skip_matching_asr` argument.
- Added `SALT_LANGUAGE_TOKENS_WHISPER` in `constants.py` with 51 African languages for new whisper-salt ASR model training.
- Sped up `multilingual_eval_fn` in `metrics.py` by skipping an unnecessary, heavy CPU audio decoding step.
- Fixed the `augment_audio_noise` function in `preprocessing.py`, which made the output audio zero-size.

The Whisper finetuning/training scripts and configs have been moved to the `sunbird-speech` repo under the `speech-to-text/whisper` directory.

[Deprecated]
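A language-token constant like the `SALT_LANGUAGE_TOKENS_WHISPER` described above might be used as in this sketch; the language codes and token IDs below are illustrative placeholders, not the real values from `constants.py`:

```python
# Hypothetical sketch of a language-token mapping in the spirit of
# SALT_LANGUAGE_TOKENS_WHISPER; codes and IDs are placeholders, not
# the real Whisper token IDs.
SALT_LANGUAGE_TOKENS_WHISPER = {
    "lug": 50001,  # placeholder ID
    "ach": 50002,  # placeholder ID
}

def language_token_id(language_code):
    """Look up the Whisper token ID for a language code, failing loudly."""
    try:
        return SALT_LANGUAGE_TOKENS_WHISPER[language_code]
    except KeyError:
        raise ValueError(f"No Whisper language token for: {language_code!r}")
```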
This PR adds support for Google Cloud buckets in the Whisper training pipeline, and makes several other changes:
- Load data from a `gcs://` path with `datasets.load_dataset` and cast the audio column to the `datasets.Audio` format.
- Use `salt.datasets` from the current repo instead of https://github.com/jqug/salt.git
- Set `gradient_checkpointing=False`.
- Use `torch_dtype=torch.float32` when loading the model weights.
- Update `model.generation_config` based on requirements from the new version.

Overfit experiment
An overfit experiment with just 100 examples was done to verify the changes:
MLflow run 1 with evaluation metrics: https://mlflow-sunbird-ce0ecfc14244.herokuapp.com/#/experiments/0/runs/2d488acdc39146e9af9da07c00128d49/model-metrics
MLflow run 2 with GPU utilization: https://mlflow.sunbird.ai/#/experiments/0/runs/811bbdf051f44597bd90c3376cfc9309/system-metrics
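The model-loading settings listed in the changes above could be collected into a config along these lines; this is a sketch, and the key names are assumptions rather than the pipeline's actual config schema:

```python
# Illustrative config dict mirroring settings described in this PR;
# the key names are assumptions, not the pipeline's real schema.
TRAIN_CONFIG = {
    "gradient_checkpointing": False,  # disabled in this PR
    "torch_dtype": "float32",         # used when loading the model weights
}

def model_load_kwargs(config):
    """Translate the config into keyword arguments for model loading."""
    return {"torch_dtype": config["torch_dtype"]}
```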
TODO
- Extend `salt.constants.SALT_LANGUAGE_TOKENS_WHISPER` to support new languages. Currently we only have the following: